{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4cee136",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "### Install Dependencies\n",
    "\n",
    "First we need to install dependencies to support operator training and inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d0b9ef7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m pip install torch torchvision torchaudio torchmetrics==0.7.0 towhee 'towhee.models>=0.8.0'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deb511ec",
   "metadata": {},
   "source": [
    "### Download dataset\n",
    "This operator is pretrained on the [FMA dataset](https://github.com/mdeff/fma), and we fine-tune it on the [GTZAN dataset](https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification). In addition to GTZAN, we also need to download several datasets that are used to add noise during training: the [Microphone impulse response dataset](http://micirp.blogspot.com/), the [Aachen Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/aachen-impulse-response-database/), and [AudioSet](https://research.google.com/audioset/download.html). These datasets are all publicly available. Please contact us if there are any copyright issues.\n",
    "\n",
    "We followed this [guide](https://github.com/stdio2016/pfann#prepare-dataset) for data preprocessing, so all you need to do is download the processed data and metadata directly.\n",
    "\n",
    "You can create a folder to store all the downloaded data. It needs about 4 GB of space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "142e3b5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Create the dataset folder if it does not exist yet.\n",
    "dataset_path = './dataset'\n",
    "os.makedirs(dataset_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a344485d",
   "metadata": {},
   "source": [
    "#### Download GTZAN dataset\n",
    "The [GTZAN dataset](https://www.kaggle.com/datasets/andradaolteanu/gtzan-dataset-music-genre-classification) contains 1,000 tracks, each 30 seconds long. There are 10 genres with 100 tracks each, and all tracks are 22,050 Hz mono 16-bit audio files in .wav format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e9f68250",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100 1168M  100 1168M    0     0  4129k      0  0:04:49  0:04:49 --:--:-- 4694k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/gtzan_full.zip -o ./dataset/genres_original.zip\n",
    "! unzip -q -o ./dataset/genres_original.zip -d ./dataset\n",
    "! rm ./dataset/genres_original.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51d6933b",
   "metadata": {},
   "source": [
    "#### Download Microphone impulse response dataset\n",
    "The [Microphone impulse response dataset](http://micirp.blogspot.com/) contains specially recorded microphone impulse responses, which can be used for adding noise during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4c0f31b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100  152k  100  152k    0     0  83217      0  0:00:01  0:00:01 --:--:-- 1456k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/micirp.zip -o ./dataset/micirp.zip\n",
    "! unzip -q -o ./dataset/micirp.zip -d ./dataset\n",
    "! rm ./dataset/micirp.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756b33d5",
   "metadata": {},
   "source": [
    "#### Download Aachen Impulse Response Database\n",
    "[Aachen Impulse Response Database](https://www.iks.rwth-aachen.de/en/research/tools-downloads/databases/aachen-impulse-response-database/) is a set of impulse responses that were measured in a wide variety of rooms. It can be used for adding noise during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ac55bd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100  193M  100  193M    0     0  3854k      0  0:00:51  0:00:51 --:--:-- 4098k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/AIR_1_4.zip -o ./dataset/AIR_1_4.zip\n",
    "! unzip -q -o ./dataset/AIR_1_4.zip -d ./dataset\n",
    "! rm ./dataset/AIR_1_4.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8680aea",
   "metadata": {},
   "source": [
    "#### Download AudioSet\n",
    "[AudioSet](https://research.google.com/audioset/download.html) is an audio event dataset, which consists of over 2M human-annotated 10-second video clips. We also use it for adding noise during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "15ca989c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100 1500M  100 1500M    0     0  3133k      0  0:08:10  0:08:10 --:--:-- 4729k\n",
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100  582M  100  582M    0     0  2620k      0  0:03:47  0:03:47 --:--:-- 2613k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/audioset_p1 -o ./dataset/audioset_p1\n",
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/audioset_p2 -o ./dataset/audioset_p2\n",
    "! cat ./dataset/audioset_p* > ./dataset/audioset.zip\n",
    "! unzip -q -o ./dataset/audioset.zip -d ./dataset\n",
    "! rm ./dataset/audioset_p* ./dataset/audioset.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ab69b3f",
   "metadata": {},
   "source": [
    "#### Download data information for training\n",
    "This archive contains the preprocessed data information used for training, including the CSV file lists and the default training configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f6aa31d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100 25286  100 25286    0     0  14237      0  0:00:01  0:00:01 --:--:--  355k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/gtzan_info.zip -o ./gtzan_info.zip\n",
    "! unzip -q -o ./gtzan_info.zip\n",
    "! rm ./gtzan_info.zip"
   ]
  },
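  {
   "cell_type": "markdown",
   "id": "c7d41e02",
   "metadata": {},
   "source": [
    "Before moving on, it can help to confirm that the archives were extracted where the training step expects them. The following is a minimal sketch rather than part of the original workflow: the directory names are taken from `gtzan_info/default_gtzan.json`, and the check assumes the notebook is run from the folder that contains `dataset` and `gtzan_info`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e9a6b13",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Directories referenced by gtzan_info/default_gtzan.json.\n",
    "expected_dirs = [\n",
    "    'dataset/genres_original',\n",
    "    'dataset/audioset',\n",
    "    'dataset/micirp',\n",
    "    'dataset/AIR_1_4',\n",
    "    'gtzan_info/lists',\n",
    "]\n",
    "\n",
    "# Collect every expected directory that is not present on disk.\n",
    "missing = [d for d in expected_dirs if not Path(d).is_dir()]\n",
    "print('missing directories:', missing if missing else 'none')"
   ]
  },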
  {
   "cell_type": "markdown",
   "id": "cebb5f4c",
   "metadata": {},
   "source": [
    "## Fine-tune Audio Embedding with Neural Network Fingerprint operator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5e5f3fd",
   "metadata": {},
   "source": [
    "### Instantiate operator\n",
    "We can instantiate a towhee nnfp operator. This audio embedding operator converts an input audio clip into dense vectors that represent the clip's semantics. Each vector represents an audio segment with a fixed length of around 1s. The operator generates audio embeddings with the fingerprinting method introduced by Neural Audio Fingerprint, which makes it suitable for audio fingerprinting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "84376759",
   "metadata": {},
   "outputs": [],
   "source": [
    "import towhee\n",
    "nnfp_op = towhee.ops.audio_embedding.nnfp().get_op()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f2fa8a5d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-0.15469249, -0.02260398, -0.05088959, ...,  0.14650534,\n",
       "          0.04951884, -0.04235527],\n",
       "        [-0.00608123, -0.06859994, -0.0750239 , ...,  0.0840608 ,\n",
       "          0.12196919, -0.1123263 ],\n",
       "        [-0.18665867,  0.08474724,  0.03795987, ...,  0.06031123,\n",
       "         -0.09239668, -0.08622654],\n",
       "        ...,\n",
       "        [ 0.02841254,  0.01915257,  0.02964114, ...,  0.04307787,\n",
       "         -0.08863434,  0.0016751 ],\n",
       "        [-0.0166699 ,  0.08893833,  0.05510458, ...,  0.13624884,\n",
       "          0.03493905, -0.13401009],\n",
       "        [-0.04592355, -0.07944845,  0.09267115, ...,  0.02575601,\n",
       "         -0.09419111,  0.03918429]], dtype=float32),\n",
       " (10, 128))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_audio = 'dataset/audioset/B4lyT64WFjc_0.wav'\n",
    "embedding = nnfp_op(test_audio)\n",
    "embedding, embedding.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3524531f",
   "metadata": {},
   "source": [
    "### Start training\n",
    "When initialized, this operator already contains a model with weights trained on the FMA data. The goal of our training is to fine-tune it on a different audio dataset so that it better fits the new data distribution.\n",
    "\n",
    "Let's first look at the default training configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "41671a97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\r\n",
      "  \"train_csv\": \"gtzan_info/lists/gtzan_train.csv\",\r\n",
      "  \"validate_csv\": \"gtzan_info/lists/gtzan_valtest.csv\",\r\n",
      "  \"test_csv\": \"gtzan_info/lists/gtzan_valtest.csv\",\r\n",
      "  \"music_dir\": \"dataset/genres_original\",\r\n",
      "  \"model_dir\": \"fma_test\",\r\n",
      "  \"cache_dir\": \"caches\",\r\n",
      "  \"batch_size\": 640,\r\n",
      "  \"shuffle_size\": null,\r\n",
      "  \"fftconv_n\": 32768,\r\n",
      "  \"sample_rate\": 8000,\r\n",
      "  \"stft_n\": 1024,\r\n",
      "  \"stft_hop\": 256,\r\n",
      "  \"n_mels\": 256,\r\n",
      "  \"f_min\": 300,\r\n",
      "  \"f_max\": 4000,\r\n",
      "  \"segment_size\": 1,\r\n",
      "  \"hop_size\": 0.5,\r\n",
      "  \"time_offset\": 1.2,\r\n",
      "  \"pad_start\": 0,\r\n",
      "  \"epoch\": 1,\r\n",
      "  \"lr\": 1e-4,\r\n",
      "  \"tau\": 0.05,\r\n",
      "  \"noise\": {\r\n",
      "    \"train\": \"gtzan_info/lists/noise_train.csv\",\r\n",
      "    \"validate\": \"gtzan_info/lists/noise_val.csv\",\r\n",
      "    \"dir\": \"dataset/audioset\",\r\n",
      "    \"snr_max\": 10,\r\n",
      "    \"snr_min\": 0\r\n",
      "  },\r\n",
      "  \"micirp\": {\r\n",
      "    \"train\": \"gtzan_info/lists/micirp_train.csv\",\r\n",
      "    \"validate\": \"gtzan_info/lists/micirp_val.csv\",\r\n",
      "    \"dir\": \"dataset/micirp\",\r\n",
      "    \"length\": 0.5\r\n",
      "  },\r\n",
      "  \"air\": {\r\n",
      "    \"train\": \"gtzan_info/lists/air_train.csv\",\r\n",
      "    \"validate\": \"gtzan_info/lists/air_val.csv\",\r\n",
      "    \"dir\": \"dataset/AIR_1_4\",\r\n",
      "    \"length\": 1\r\n",
      "  },\r\n",
      "  \"cutout_min\": 0.1,\r\n",
      "  \"cutout_max\": 0.5,\r\n",
      "  \"model\": {\r\n",
      "    \"d\": 128,\r\n",
      "    \"h\": 1024,\r\n",
      "    \"u\": 32,\r\n",
      "    \"fuller\": true,\r\n",
      "    \"conv_activation\": \"ReLU\"\r\n",
      "  },\r\n",
      "  \"indexer\": {\r\n",
      "    \"index_factory\": \"IVF200,PQ64x8np\",\r\n",
      "    \"top_k\": 100,\r\n",
      "    \"frame_shift_mul\": 1\r\n",
      "  }\r\n",
      "}\r\n"
     ]
    }
   ],
   "source": [
    "! cat ./gtzan_info/default_gtzan.json"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "473e06e1",
   "metadata": {},
   "source": [
    "This JSON file contains training settings such as the number of epochs and the batch size, as well as data and model information.  \n",
    "It also lists several CSV file paths, each of which describes the audio files of the corresponding dataset.  \n",
    "We only need to pass the path of this JSON file to the `train()` interface to train the operator."
   ]
  },
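  {
   "cell_type": "markdown",
   "id": "9b2f7a40",
   "metadata": {},
   "source": [
    "To change a setting, one option is to write a modified copy of the configuration and pass the path of that copy to `train()` instead. The sketch below is an illustration only: the file name `my_gtzan.json` and the batch size of 128 are hypothetical choices, not values recommended by the operator's authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03c58d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "default_config_path = './gtzan_info/default_gtzan.json'\n",
    "custom_config_path = './gtzan_info/my_gtzan.json'  # hypothetical file name\n",
    "\n",
    "with open(default_config_path) as f:\n",
    "    config = json.load(f)\n",
    "\n",
    "# Inspect a few of the training settings.\n",
    "print({key: config[key] for key in ('epoch', 'batch_size', 'lr')})\n",
    "\n",
    "# Hypothetical change: a smaller batch size, e.g. for a GPU with less memory.\n",
    "config['batch_size'] = 128\n",
    "\n",
    "with open(custom_config_path, 'w') as f:\n",
    "    json.dump(config, f, indent=2)\n",
    "\n",
    "# The copy could then be used as:\n",
    "# nnfp_op.train(config_json_path=custom_config_path)"
   ]
  },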
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "64cc63e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading noise dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1193/1193 [00:15<00:00, 74.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([95161077])\n",
      "loading Aachen IR dataset\n",
      "loading microphone IR dataset\n",
      "load cached music from caches/1gtzan_train.bin\n",
      "training data contains 47200 samples\n",
      "loading noise dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 299/299 [00:03<00:00, 75.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([23878784])\n",
      "loading Aachen IR dataset\n",
      "loading microphone IR dataset\n",
      "load cached music from caches/1gtzan_valtest.bin\n",
      "evaluate before fine-tune...\n",
      "\n",
      "validate score: 0.805910\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-01-13 11:42:46,713 - 140227902603840 - trainer.py-trainer:319 - WARNING: TrainingConfig(output_dir='fine_tune_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=640, val_batch_size=-1, seed=42, epoch_num=1, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=0.0001, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard={'log_dir': None, 'comment': ''}, loss='CrossEntropyLoss', optimizer='custom_', lr_scheduler_type='cosine', warmup_ratio=0.0, warmup_steps=0, device_str=None, sync_bn=False, freeze_bn=False)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f1ddfa67061446394cfe1003b553e5e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/148 [00:00<?, ?step/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "validate score: 0.813064\n"
     ]
    }
   ],
   "source": [
    "nnfp_op.train(config_json_path='./gtzan_info/default_gtzan.json')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d00b421",
   "metadata": {},
   "source": [
    "We fine-tune for only one epoch because too many epochs can lead to overfitting. On the new dataset distribution, the loss has decreased and the validation score has improved from 0.805910 before fine-tuning to 0.813064 after it.  \n",
    "For more training details, refer to [this script](https://towhee.io/audio-embedding/nnfp/src/branch/main/train_nnfp.py). If you want to fine-tune with your own data or training procedure, you can use it as a starting point and modify it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a86a0fe7",
   "metadata": {},
   "source": [
    "### Use your fine-tuned weights\n",
    "To extract embeddings with your fine-tuned weights, load them by passing the saved model path when instantiating the operator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b9305bc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-0.16501123, -0.01201352, -0.0383786 , ...,  0.12018757,\n",
       "          0.04253541, -0.01914779],\n",
       "        [ 0.00619547, -0.07101642, -0.06639319, ...,  0.06269935,\n",
       "          0.12877102, -0.06712442],\n",
       "        [-0.18207397,  0.09906715,  0.02120742, ...,  0.03872132,\n",
       "         -0.07399878, -0.07401986],\n",
       "        ...,\n",
       "        [ 0.03872605,  0.01387926,  0.04619657, ..., -0.02789425,\n",
       "         -0.06140808,  0.03822998],\n",
       "        [ 0.00434632,  0.08006225,  0.05249849, ...,  0.0985365 ,\n",
       "          0.03861852, -0.10067634],\n",
       "        [-0.039698  , -0.07496819,  0.09917679, ..., -0.02591049,\n",
       "         -0.06788268,  0.0777601 ]], dtype=float32),\n",
       " (10, 128))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_nnfp_op = towhee.ops.audio_embedding.nnfp(model_path='./fine_tune_output/final_epoch/model.pth').get_op()\n",
    "embedding = new_nnfp_op(test_audio)\n",
    "embedding, embedding.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b8b2c68",
   "metadata": {},
   "source": [
    "We can observe that the inference output after training differs from the previous output, but only slightly. This is consistent with fine-tuning the model rather than drastically changing its weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7608623f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nnfp",
   "language": "python",
   "name": "nnfp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}